{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST Visualization Example\n",
    "\n",
    "Real-time visualization of a CNN being trained on MNIST, using TensorFlow and [TensorDebugger](https://github.com/ericjang/tdb).\n",
    "\n",
    "The visualizations in this notebook won't show up on http://nbviewer.ipython.org. To view the widgets and interact with them, you will need to download this notebook and run it with a Jupyter Notebook server."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Load TDB Notebook Extension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "Jupyter.utils.load_extensions('tdb_ext/main')"
      ],
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%javascript\n",
    "// Ask the notebook front end to load the 'tdb_ext/main' extension (the TDB notebook extension used by this demo).\n",
    "Jupyter.utils.load_extensions('tdb_ext/main')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "#import sys\n",
    "#sys.path.append('/home/evjang/thesis/tensor_debugger')\n",
    "import tdb\n",
    "from tdb.examples import mnist, viz\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import urllib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Build TensorFlow Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Construct the model via mnist.build_model() and unpack every handle it returns.\n",
    "(\n",
    "    # input nodes (the two train_* nodes are the feed_dict keys used by the training loop)\n",
    "    train_data_node,\n",
    "    train_labels_node,\n",
    "    validation_data_node,\n",
    "    test_data_node,\n",
    "    # predictions\n",
    "    train_prediction,\n",
    "    validation_prediction,\n",
    "    test_prediction,\n",
    "    # weights\n",
    "    conv1_weights,\n",
    "    conv2_weights,\n",
    "    fc1_weights,\n",
    "    fc2_weights,\n",
    "    # training\n",
    "    optimizer,\n",
    "    loss,\n",
    "    learning_rate,\n",
    "    summaries,\n",
    ") = mnist.build_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Attach Plotting Ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def viz_activations(ctx, m):\n",
    "    \"\"\"Draw matrix `m` as a grayscale image labelled as LeNet predictions.\n",
    "\n",
    "    ctx -- accepted but not used by this function.\n",
    "    m   -- matrix to display; it is transposed before drawing, so presumably it\n",
    "           arrives as (batch, digit) -- TODO confirm against the plot_op caller.\n",
    "    \"\"\"\n",
    "    plt.matshow(m.T, cmap=plt.cm.gray)\n",
    "    # label the axes that matshow has just made current\n",
    "    axes = plt.gca()\n",
    "    axes.set_title(\"LeNet Predictions\")\n",
    "    axes.set_xlabel(\"Batch\")\n",
    "    axes.set_ylabel(\"Digit Activation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# plotting a user-defined function 'viz_activations'\n",
    "p0=tdb.plot_op(viz_activations,inputs=[train_prediction])\n",
    "# weight variables are of type tf.Variable, so we need to find the corresponding tf.Tensor instead\n",
    "g=tf.get_default_graph()\n",
    "# one plot node per weight tensor: p1/p2 for the conv layers, p3/p4 for the fully connected layers\n",
    "p1=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv1_weights)])\n",
    "p2=tdb.plot_op(viz.viz_conv_weights,inputs=[g.as_graph_element(conv2_weights)])\n",
    "p3=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc1_weights)])\n",
    "p4=tdb.plot_op(viz.viz_fc_weights,inputs=[g.as_graph_element(fc2_weights)])\n",
    "# histogram of the conv1 weights; it has its own name so that it no longer overwrites p2 (the conv2 weight plot).\n",
    "# add phist to the list passed to tdb.debug() to have it evaluated.\n",
    "phist=tdb.plot_op(viz.viz_conv_hist,inputs=[g.as_graph_element(conv1_weights)])\n",
    "# plot node that watches the training loss\n",
    "ploss=tdb.plot_op(viz.watch_loss,inputs=[loss])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Download the MNIST dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train-images-idx3-ubyte.gz\n",
      "train-labels-idx1-ubyte.gz\n",
      "t10k-images-idx3-ubyte.gz\n",
      "t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# download the four gzipped MNIST archives into download_dir, printing each file name as it is fetched\n",
    "base_url='http://yann.lecun.com/exdb/mnist/'\n",
    "files=[\n",
    "    'train-images-idx3-ubyte.gz',\n",
    "    'train-labels-idx1-ubyte.gz',\n",
    "    't10k-images-idx3-ubyte.gz',\n",
    "    't10k-labels-idx1-ubyte.gz',\n",
    "]\n",
    "download_dir='/tmp/'\n",
    "for f in files:\n",
    "    print(f)\n",
    "    # remote location and local destination share the same file name\n",
    "    source_url=base_url+f\n",
    "    target_path=download_dir+f\n",
    "    urllib.urlretrieve(source_url, target_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Debug + Visualize!\n",
    "\n",
    "Upon evaluating the plot nodes p0, p1, p2, p3, p4 and ploss, plots will be generated in the Plot view on the right."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Extracting', '/tmp/train-images-idx3-ubyte.gz')\n",
      "('Extracting', '/tmp/train-labels-idx1-ubyte.gz')\n",
      "('Extracting', '/tmp/t10k-images-idx3-ubyte.gz')\n",
      "('Extracting', '/tmp/t10k-labels-idx1-ubyte.gz')\n"
     ]
    }
   ],
   "source": [
    "# load the MNIST splits from the archives in download_dir; these are the datasets\n",
    "# themselves (train_data / train_labels are sliced into feed_dict batches by the\n",
    "# training loop), not graph nodes\n",
    "(\n",
    "    train_data,\n",
    "    train_labels,\n",
    "    validation_data,\n",
    "    validation_labels,\n",
    "    test_data,\n",
    "    test_labels,\n",
    ") = mnist.get_data(download_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# open the TensorFlow session that evaluates the graph, then initialize all variables in it\n",
    "s=tf.InteractiveSession()\n",
    "s.run(tf.initialize_all_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 29.668428\n",
      "loss: 15.983353\n",
      "loss: 11.249242\n",
      "loss: 10.028939\n",
      "loss: 8.065391\n",
      "loss: 9.335689\n",
      "loss: 7.316875\n",
      "loss: 8.376289\n",
      "loss: 7.735221\n",
      "loss: 8.383675\n",
      "loss: 5.704120\n",
      "loss: 6.037778\n",
      "loss: 7.309663\n",
      "loss: 7.349874\n",
      "loss: 7.528041\n",
      "loss: 8.209503\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-9-ccc202b40680>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     12\u001b[0m     }\n\u001b[0;32m     13\u001b[0m     \u001b[1;31m# run training node and visualization node\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 14\u001b[1;33m     \u001b[0mstatus\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtdb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0ms\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     15\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mstep\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;36m10\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m         \u001b[0mstatus\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtdb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdebug\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp1\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp3\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mp4\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mploss\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbreakpoints\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbreak_immediately\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mFalse\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0msession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0ms\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/interface.pyc\u001b[0m in \u001b[0;36mdebug\u001b[1;34m(evals, feed_dict, breakpoints, break_immediately, session)\u001b[0m\n\u001b[0;32m     15\u001b[0m         \u001b[1;32mglobal\u001b[0m \u001b[0m_dbsession\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     16\u001b[0m         \u001b[0m_dbsession\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdebug_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mDebugSession\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 17\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0m_dbsession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mevals\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mbreakpoints\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mbreak_immediately\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     19\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, evals, feed_dict, breakpoints, break_immediately)\u001b[0m\n\u001b[0;32m     59\u001b[0m                         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_break\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     60\u001b[0m                 \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 61\u001b[1;33m                         \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mc\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     63\u001b[0m         \u001b[1;32mdef\u001b[0m \u001b[0ms\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36mc\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     85\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     86\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mRUNNING\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_eval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     88\u001b[0m                 \u001b[1;31m# increment to next node\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     89\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc\u001b[0m in \u001b[0;36m_eval\u001b[1;34m(self, node)\u001b[0m\n\u001b[0;32m    168\u001b[0m                 \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;31m# is a TensorFlow node\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    169\u001b[0m                         \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 170\u001b[1;33m                                 \u001b[0mresult\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cache\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    171\u001b[0m                                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_cache\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnode\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    172\u001b[0m                         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict)\u001b[0m\n\u001b[0;32m    346\u001b[0m     \u001b[1;31m# Run request and get response.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    347\u001b[0m     \u001b[1;31m#pdb.set_trace()\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 348\u001b[1;33m     \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_run\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0munique_fetch_targets\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict_string\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    349\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    350\u001b[0m     \u001b[1;31m# User may have fetched the same tensor multiple times, but we\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, target_list, fetch_list, feed_dict)\u001b[0m\n\u001b[0;32m    405\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    406\u001b[0m       return tf_session.TF_Run(self._session, feed_dict, fetch_list,\n\u001b[1;32m--> 407\u001b[1;33m                                target_list)\n\u001b[0m\u001b[0;32m    408\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    409\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mStatusNotOK\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 64\n",
    "NUM_EPOCHS = 5\n",
    "TRAIN_SIZE=10000\n",
    "\n",
    "# number of mini-batch steps needed for NUM_EPOCHS passes over TRAIN_SIZE examples\n",
    "num_steps = NUM_EPOCHS * TRAIN_SIZE // BATCH_SIZE\n",
    "\n",
    "for step in xrange(num_steps):\n",
    "    # window of BATCH_SIZE rows; the start wraps around so the window stays inside the first TRAIN_SIZE examples\n",
    "    offset = (step * BATCH_SIZE) % (TRAIN_SIZE - BATCH_SIZE)\n",
    "    batch = slice(offset, offset + BATCH_SIZE)\n",
    "    batch_data = train_data[batch, :, :, :]\n",
    "    batch_labels = train_labels[batch]\n",
    "    feed_dict = {\n",
    "        train_data_node: batch_data,\n",
    "        train_labels_node: batch_labels,\n",
    "    }\n",
    "    # run training node and visualization node\n",
    "    status,result=tdb.debug([optimizer,p0], feed_dict=feed_dict, session=s)\n",
    "    if step % 10 != 0:\n",
    "        continue\n",
    "    # every 10th step: evaluate the loss together with the other plot nodes and report the loss\n",
    "    status,result=tdb.debug([loss,p1,p2,p3,p4,ploss], feed_dict=feed_dict, breakpoints=None, break_immediately=False, session=s)\n",
    "    print('loss: %f' % (result[0]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
